
# ## Install devtools if not already installed
# install.packages("devtools", repos='http://cran.us.r-project.org')
# library(devtools)
# ## Install augsynth from github
# devtools::install_github("ebenmichael/augsynth", force = T)
# library(augsynth)
## general libraries
library(dplyr)
library(tidyverse)
library(ggplot2)
library(MASS)
# Colour-blind-friendly palette (the variant that starts with grey) used for
# the per-method fills in the figures.
cbp1 <- c("#999999", "#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7")

# Weight of the common-factor component in the simulated outcome model; the
# outcome-specific component receives weight 1 - rho. Settings used so far:
# rho = 1, 0.5 and 0. The figure file names are derived from rho so they can
# never drift out of sync with it (1 -> "rho1", 0.5 -> "rho05", 0 -> "rho0").
rho <- 1
rho_tag <- gsub(".", "", as.character(rho), fixed = TRUE)
bias_fig_name <- file.path("figure", paste0("all_bias_sims_n50_rho", rho_tag, ".pdf"))
imbalance_fig_name <- file.path("figure", paste0("all_imbalance_sims_n50_rho", rho_tag, ".pdf"))

##### SCM function -------------------------------------

#' Solve the synthetic-control quadratic program directly with OSQP
#'
#' Minimises 1/2 w' (X0 X0') w - X1' X0' w, i.e. 1/2 ||t(X0) %*% w - X1||^2 up
#' to a constant, subject to sum(w) == 1 and 0 <= w <= 1.
#' @param X1 Target vector (one entry per column of X0)
#' @param X0 Matrix of control outcomes, one row per control unit
#' @return Numeric vector of weights, one per row of X0 (OSQP's primal solution)
#' @noRd
synth_qp <- function(X1, X0) {
  n_controls <- nrow(X0)

  # Quadratic and linear terms of the objective.
  quad_term <- X0 %*% t(X0)
  linear_term <- -t(X1) %*% t(X0)

  # Row 1 of the constraint matrix enforces sum-to-one (lower = upper = 1);
  # the identity rows bound each weight to [0, 1].
  constraint_mat <- rbind(rep(1, n_controls), diag(n_controls))
  lower_bounds <- c(1, rep(0, n_controls))
  upper_bounds <- rep(1, n_controls + 1)

  solver_opts <- osqp::osqpSettings(
    verbose = FALSE,
    eps_rel = 1e-8,
    eps_abs = 1e-8
  )
  osqp::solve_osqp(
    P = quad_term, q = linear_term,
    A = constraint_mat, l = lower_bounds, u = upper_bounds,
    pars = solver_opts
  )$x
}

##### Bias
#' Gap between treated outcome(s) and the synthetic control
#'
#' @param Y0 Control outcomes: a vector with one entry per control unit, or a
#'   periods x controls matrix
#' @param Y1 Treated outcome(s), conformable with `Y0 %*% weights`
#' @param weights Control-unit weights
#' @return `Y1 - Y0 %*% weights`; a matrix, because `%*%` always returns one
SCMbias <- function(Y0, Y1, weights) {
  Y1 - Y0 %*% weights
}

#### RMSE
#' Root-mean-square gap between treated outcomes and the synthetic control
#'
#' @param Y0 Periods x controls matrix of control outcomes
#' @param Y1 Treated outcomes, one per period
#' @param weights Control-unit weights
#' @return Scalar RMSE of `Y1 - Y0 %*% weights`
rmse <- function(Y0, Y1, weights) {
  residual <- Y1 - Y0 %*% weights
  mean_sq <- mean(residual^2)
  sqrt(mean_sq)
}

#' Simulate one noisy draw of the factor model and evaluate three SCM variants
#'
#' Adds i.i.d. N(0, variance) noise to the pre-treatment rows of `model`, then
#' fits synthetic-control weights with synth_qp() on
#'   * "sep": the first outcome only,
#'   * "cat": all outcomes concatenated,
#'   * "avg": the per-period average across outcomes,
#' and evaluates each fit.
#'
#' @param model Noiseless outcomes: an (n * (t_total + 1)) x k_total matrix
#'   whose rows are ordered period by period (n unit rows per period). Rows of
#'   the final period are used only for the oracle bias.
#' @param n number of units
#' @param trt 0/1 treatment indicator vector of length n with exactly one 1
#' @param k_total number of outcomes
#' @param t_total number of pre-treatment periods
#' @param variance variance of the noise term
#' @return A list whose entries are each ordered (sep, cat, avg):
#'   `oracle_bias` - post-period gap between the treated unit's noiseless first
#'     outcome and its synthetic control;
#'   `imbalance` - pre-treatment RMSE on the data each variant was fitted to;
#'   `largest_svd` - 1 x 3 matrix, share of the squared singular values carried
#'     by the leading one for the stacked (units x features) fitting data;
#'   `cond` - 1 x 3 matrix, largest / smallest singular value of that data.
factor_sim <- function(model, n, trt, k_total, t_total, variance) {
  # Logical mask: indexing with the numeric 0/1 vector itself (x[trt], x[-trt])
  # only selects the treated unit when it happens to be unit 1.
  is_trt <- trt == 1
  stopifnot(length(trt) == n, sum(is_trt) == 1)

  # Leading singular-value share and condition number of a data matrix. The
  # smallest singular value is the last one, so this also works when n <= t_total.
  svd_diagnostics <- function(X) {
    d <- svd(X)$d
    c(d[1]^2 / sum(d^2), d[1] / d[length(d)])
  }

  ## generate factor model data for the pre-treatment periods
  epsilon <- matrix(rnorm((n * t_total) * k_total, sd = sqrt(variance)), ncol = k_total)
  out <- model[1:(n * t_total), , drop = FALSE] + epsilon # (N x T) x K outcomes
  t_trt <- rep(is_trt, t_total) # treated-unit mask for every pre-period row
  out_trt <- matrix(out[t_trt, , drop = FALSE], nrow = t_total, ncol = k_total) # T x K, treated unit
  out_control <- out[!t_trt, , drop = FALSE] # ((N-1) x T) x K, control units

  ## "sep": first outcome only; controls as an (N-1) x T matrix
  out_trt_sep <- out_trt[, 1]
  out_control_sep <- matrix(out_control[, 1], nrow = n - 1, ncol = t_total)
  w_sep <- synth_qp(out_trt_sep, out_control_sep)
  svd_sep <- svd_diagnostics(rbind(out_control_sep, out_trt_sep))

  ## "cat": outcomes concatenated; columns run over periods within each outcome
  out_trt_cat <- matrix(out_trt, nrow = t_total * k_total, ncol = 1)
  out_control_cat <- matrix(out_control, nrow = n - 1, ncol = t_total * k_total)
  w_cat <- synth_qp(out_trt_cat, out_control_cat)
  svd_cat <- svd_diagnostics(rbind(out_control_cat, t(out_trt_cat)))

  ## "avg": per-period mean across outcomes
  out_trt_avg <- rowMeans(out_trt)
  out_control_avg <- matrix(rowMeans(out_control), nrow = n - 1, ncol = t_total)
  w_avg <- synth_qp(out_trt_avg, out_control_avg)
  svd_avg <- svd_diagnostics(rbind(out_control_avg, t(out_trt_avg)))

  ## oracle bias: noiseless first outcome in the post-treatment period
  model_t1 <- model[((n * t_total) + 1):(n * (t_total + 1)), , drop = FALSE]
  bias_of <- function(w) drop(SCMbias(model_t1[!is_trt, 1], model_t1[is_trt, 1], w))

  ## pre-treatment imbalance on the data each variant was fitted to
  imbalance_sep <- rmse(t(out_control_sep), out_trt_sep, w_sep)
  imbalance_cat <- rmse(t(out_control_cat), out_trt_cat, w_cat)
  imbalance_avg <- rmse(t(out_control_avg), out_trt_avg, w_avg)

  list(
    "oracle_bias" = c(bias_of(w_sep), bias_of(w_cat), bias_of(w_avg)),
    "imbalance" = c(imbalance_sep, imbalance_cat, imbalance_avg),
    "largest_svd" = matrix(c(svd_sep[1], svd_cat[1], svd_avg[1]), nrow = 1),
    "cond" = matrix(c(svd_sep[2], svd_cat[2], svd_avg[2]), nrow = 1)
  )
}


# Factor loadings --------------------------------------------------------------
n <- 50 # number of units; unit 1 is the treated unit

# Common-factor loadings fall linearly from 5 to 1 across units. Swapping the
# first two entries gives unit 1 (the treated unit) the second-largest loading.
mu <- list(c = seq(5, 1, length.out = n))
mu$c[1:2] <- rev(mu$c[1:2])

# Treatment indicator: 1 for the treated unit (unit 1), 0 for every control.
trt <- replace(numeric(n), 1, 1)

# Simulation settings, run in this order: number of pre-treatment periods
# (t_total), number of outcomes (k_total), and the label used for the setting
# in the results and figures.
settings_list <- list(
  list(t_total = 10, k_total = 4, file_suffix = "T_0=10,K=4"),
  list(t_total = 10, k_total = 10, file_suffix = "T_0=10,K=10"),
  list(t_total = 40, k_total = 4, file_suffix = "T_0=40,K=4"),
  list(t_total = 40, k_total = 10, file_suffix = "T_0=40,K=10")
)

# Accumulators for the long-format results of every setting.
all_bias_sim_long <- data.frame()
all_imbalance_sim_long <- data.frame()
all_largest_svd_sim_long <- data.frame()
all_cond_sim_long <- data.frame()

##### Simulation loop ----------------------------------------------------------
# For each (T_0, K) setting: build the factors and loadings, form the
# noiseless outcome model, run `n_sim` noisy replications through
# factor_sim(), and append long-format results to the all_*_sim_long
# accumulators.
n_sim <- 1000 # Monte Carlo replications per setting
columns <- c("sep", "cat", "avg") # SC variants, in the order factor_sim() returns them
# Logical treated-unit mask. Indexing with the numeric 0/1 `trt` (x[trt], x[-trt])
# only selects the right unit because the treated unit happens to be unit 1.
is_trt <- trt == 1

# Reshape a wide n_sim x 3 results frame (one column per method) to long form,
# tagged with the setting label.
sim_to_long <- function(df, value_name, setting_label) {
  df %>%
    pivot_longer(everything(), names_to = "method", values_to = value_name) %>%
    mutate(setting = setting_label)
}

for (setting in settings_list) {
  t_total <- setting$t_total
  k_total <- setting$k_total

  ## Factors: a common linear trend plus one AR(1) series per outcome, each
  ## rescaled to [0.5, 1] and sorted so it is non-decreasing over time.
  factors <- list(
    c = seq(0.5, 1, length.out = t_total + 1) # hand-code factor (common part)
  )
  factors$i <- matrix(0, nrow = t_total + 1, ncol = k_total) # outcome-specific part of factor
  set.seed(1)
  for (k in seq_len(k_total)) {
    series <- arima.sim(model = list(ar = 0.5), n = t_total + 1) # independent AR series across outcomes
    factors$i[, k] <- 0.5 + (series - min(series)) * (1 - 0.5) / (max(series) - min(series))
    factors$i[, k] <- sort(factors$i[, k])
  }

  ## Outcome-specific loadings
  set.seed(1)
  mu$i <- matrix(0, nrow = n, ncol = k_total)
  for (k in seq_len(k_total)) {
    mu$i[, k] <- rnorm(n)
    # NOTE(review): unlike the factor rescaling above, min(mu$i[, k]) is not
    # subtracted here, so the loadings span a width of 4 but are not shifted
    # into [1, 5] as "same range" suggests. Left unchanged so the simulated
    # design stays reproducible -- confirm intent.
    mu$i[, k] <- 1 + mu$i[, k] * (5 - 1) / (max(mu$i[, k]) - min(mu$i[, k]))
  }
  cor_i1_c <- cor(mu$i[, 1], mu$c)
  cor_i1_i2 <- if (k_total > 1) cor(mu$i[, 1], mu$i[, 2]) else NA_real_

  ## Oracle weights on the common loadings only
  theo_w <- synth_qp(mu$c[is_trt], as.matrix(mu$c[!is_trt]))
  check_common <- SCMbias(mu$c[!is_trt] * factors$c[t_total + 1],
                          mu$c[is_trt] * factors$c[t_total + 1], theo_w)

  # Overwrite the treated unit's outcome-specific loadings with the oracle-
  # weighted average of the controls, so exact oracle weights also exist for them.
  mu$i[is_trt, ] <- t(theo_w) %*% mu$i[!is_trt, , drop = FALSE]

  theo_w <- synth_qp(c(mu$c[is_trt], mu$i[is_trt, ]),
                     as.matrix(cbind(mu$c[!is_trt], mu$i[!is_trt, , drop = FALSE])))
  check_full <- SCMbias(mu$c[!is_trt] * factors$c[t_total + 1] + mu$i[!is_trt, 1] * factors$i[t_total + 1, 1],
                        mu$c[is_trt] * factors$c[t_total + 1] + mu$i[is_trt, 1] * factors$i[t_total + 1, 1],
                        theo_w)

  # A for loop does not auto-print, so report the sanity checks explicitly.
  message(sprintf(
    "[%s] cor(mu_i[,1], mu_c) = %.3f; cor(mu_i[,1], mu_i[,2]) = %.3f; oracle-weight bias: common = %.3g, common + specific = %.3g",
    setting$file_suffix, cor_i1_c, cor_i1_i2, drop(check_common), drop(check_full)
  ))

  ## Noiseless outcome model, rows ordered period by period (n units per period).
  # Column 1 of the outcome-specific factors/loadings is replaced by the common
  # one, so outcome 1 follows the common-factor model for every rho.
  factors$i[, 1] <- factors$c
  factors_expanded <- factors$i[rep(seq_len(t_total + 1), each = n), , drop = FALSE] # repeat each period's row n times
  mu$i[, 1] <- mu$c
  mu_expanded <- mu$i[rep(seq_len(n), t_total + 1), , drop = FALSE] # stack the unit rows once per period
  model <- rho * as.vector(factors$c %x% mu$c) + (1 - rho) * factors_expanded * mu_expanded

  ## Simulation: preallocate one row per replication instead of growing frames
  result_template <- matrix(NA_real_, nrow = n_sim, ncol = length(columns),
                            dimnames = list(NULL, columns))
  bias_mat <- result_template
  imbalance_mat <- result_template
  largest_svd_mat <- result_template
  cond_mat <- result_template

  for (s in seq_len(n_sim)) {
    set.seed(s)
    result <- factor_sim(model, n, trt, k_total, t_total, variance = 1)
    bias_mat[s, ] <- result$oracle_bias
    imbalance_mat[s, ] <- result$imbalance
    largest_svd_mat[s, ] <- result$largest_svd
    cond_mat[s, ] <- result$cond
  }

  bias_sim <- as.data.frame(bias_mat)
  imbalance_sim <- as.data.frame(imbalance_mat)
  largest_svd_sim <- as.data.frame(largest_svd_mat)
  cond_sim <- as.data.frame(cond_mat)

  # check the results
  message("[", setting$file_suffix, "] simulation summaries:")
  print(summary(bias_sim))
  print(summary(imbalance_sim))
  print(colMeans(largest_svd_sim))
  print(colMeans(cond_sim))

  all_bias_sim_long <- rbind(all_bias_sim_long,
                             sim_to_long(bias_sim, "bias", setting$file_suffix))
  all_imbalance_sim_long <- rbind(all_imbalance_sim_long,
                                  sim_to_long(imbalance_sim, "imbalance", setting$file_suffix))
  all_largest_svd_sim_long <- rbind(all_largest_svd_sim_long,
                                    sim_to_long(largest_svd_sim, "svd", setting$file_suffix))
  all_cond_sim_long <- rbind(all_cond_sim_long,
                             sim_to_long(cond_sim, "cond", setting$file_suffix))
}

# Check the factor structure in the largest setting: the average share of
# variation explained by the leading singular value, and condition numbers.
focus_setting <- "T_0=40,K=10"
print(
  all_largest_svd_sim_long %>%
    filter(setting == focus_setting) %>%
    group_by(method) %>%
    summarize(mean_x = mean(svd))
)

# Condition numbers of one method in the focus setting, in simulation order.
cond_in_focus <- function(m) {
  all_cond_sim_long %>%
    filter(setting == focus_setting, method == m) %>%
    pull(cond)
}
# Replications are matched by position: both subsets keep the simulation order.
avg_cond_avg_v_sep <- mean(cond_in_focus("avg") / cond_in_focus("sep") - 1)
cat("Average ratio of condition number for average over separate SCM-1", avg_cond_avg_v_sep, "\n")

# Order the facets as the settings were run.
setting_levels <- vapply(settings_list, function(s) s$file_suffix, character(1))
all_bias_sim_long$setting <- factor(all_bias_sim_long$setting, levels = setting_levels)
all_imbalance_sim_long$setting <- factor(all_imbalance_sim_long$setting, levels = setting_levels)


##### Figures ------------------------------------------------------------------

#' Draw a faceted boxplot of one simulation metric by SC method and save it as PDF
#'
#' The y range is set with coord_cartesian(), which only zooms the view. The
#' boxplot statistics and plotted means are therefore computed from every
#' observation; ylim() would drop out-of-range rows before computing them.
#' The plot is print()ed explicitly so the PDF is written even when this
#' script is source()d, where top-level ggplot objects are not auto-printed.
#'
#' @param data Long data frame with columns `method`, `setting` and `value_col`
#' @param value_col Name of the metric column to plot
#' @param y_label Y-axis title
#' @param y_limits Length-2 numeric visible y range
#' @param file Output PDF path
#' @param palette Fill colours for the methods
#' @return The ggplot object, invisibly
save_sim_boxplot <- function(data, value_col, y_label, y_limits, file, palette = cbp1) {
  p <- ggplot(data, aes(x = method, y = .data[[value_col]], fill = method)) +
    geom_boxplot(notch = FALSE, outlier.shape = NA) +
    stat_summary(fun = mean, geom = "point", shape = 20, size = 2, color = "black", fill = "black") +
    facet_wrap(~setting, scales = "fixed", nrow = 2) + # same y scale in every facet
    geom_hline(yintercept = 0, color = "black") + # reference line at 0 in all facets
    coord_cartesian(ylim = y_limits) +
    labs(y = y_label, x = "SC Method") +
    theme_minimal() +
    theme(
      axis.text.x = element_text(angle = 45, hjust = 1, size = 12),
      axis.text.y = element_text(size = 12),
      axis.title.x = element_text(size = 14, face = "bold"),
      axis.title.y = element_text(size = 14, face = "bold"),
      plot.title = element_text(size = 16, face = "bold", hjust = 0.5),
      plot.margin = unit(c(0, 0, 0, 0), "lines"),
      aspect.ratio = 3 / 4
    ) +
    scale_fill_manual(values = palette)

  pdf(file)
  on.exit(dev.off(), add = TRUE) # close the device even if drawing fails
  print(p)
  invisible(p)
}

save_sim_boxplot(all_bias_sim_long, "bias", "Bias", c(-0.25, 2), bias_fig_name)
save_sim_boxplot(all_imbalance_sim_long, "imbalance", "Imbalance", c(0, 1.75), imbalance_fig_name)
